# Install all libraries
# Colab has already preinstalled PyTorch for you
! pip install pytorch-lightning wandb rdkit ogb deepchem
# install PyG
import torch
VERSION = torch.__version__
! pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-{VERSION}.html
! pip install torch-geometric
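A quick way to confirm that the wheels match the preinstalled torch build is to import both packages and compare versions. This is an optional sanity check we add here (assuming a standard Colab runtime):

```python
import torch
import torch_geometric

# The PyG wheels above must be built against this exact torch/CUDA version
print(torch.__version__, torch_geometric.__version__)
print('CUDA available:', torch.cuda.is_available())
```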
Week 3 tutorial 2 - AI 4 Chemistry
Table of contents
- Relevant packages
- Inductive biases
- Graph neural network in chemistry
0. Relevant packages
Pytorch Geometric (PyG)
PyG is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data. You can also browse its documentation for additional details.
Set a random seed to ensure repeatability of experiments
import random
import numpy as np
import torch
# Random Seeds and Reproducibility
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
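If you want even stricter reproducibility on GPU, you can additionally pin cuDNN's behavior. This is an optional extension (not part of the original tutorial) and may slow training down:

```python
# Optional: force deterministic cuDNN kernels (can reduce performance)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```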
One of the promises of deep learning algorithms is that they can learn to automatically extract features from the raw data. However, so far we have used the same featurization methods as we used for the more basic models.
Can our models directly take a molecule as input?
1. Inductive biases
Inductive biases are assumptions we make about the data that help our models extract signal from it. These assumptions are encoded in the model’s architecture.
For instance, when we (humans) look at images, we think differently than when we read a book, or than when we analyze a molecule. Processing all these different types of data requires different ways of interpretation, and thus different assumptions about the data.
When building models, we attempt to encode these inductive biases in our model’s architecture so they know how to read and process the data.
A natural way of representing molecules is as graphs. A graph is a collection of nodes (atoms) and edges (bonds).
In the end, this is what we assume from the data:
- Molecules are formed by atoms connected by bonds, and each atom is influenced mostly by its closest neighbors.
- Molecular properties are determined solely by the molecular graph.
This is what we assume and thus what we tell our model. The specific details of how to calculate the solubility of a molecule (or any other property) are exactly what the model will try to learn from the data!
2. Graph neural network in chemistry
2.1 Graph representation
In graph theory, a graph \(G=(V,E)\) is defined by a set of nodes (also called vertices) \(V\) and a set of edges (also called links) \(E\) between these vertices. More specifically:
- \(V = \{ v_1, \: ..., \: v_n \}\), a set of nodes;
- \(E \subseteq \{ (i,j) \: | \: i,j \in V, \: i \neq j \}\), a set of edges representing connections between nodes.
If the edges of a graph have directions, the graph is called a directed graph, otherwise it is called an undirected graph.

In many cases we also have attribute or feature information associated with a graph:
- node features: \(\mathbf{X} = [..., \: x_i, \: ...]^T \in \mathbb{R}^{|V| \times m}\), where \(x_i \in \mathbb{R}^m\) denotes the feature of node \(i\);
- edge features: \(\mathbf{L} = [..., \: l_{i,j}, \: ...]^T \in \mathbb{R}^{|E| \times r}\), where \(l_{i,j} \in \mathbb{R}^r\) denotes the feature of the edge between node \(i\) and node \(j\);
- graph features: \(\mathbf{G} = (..., \: g_i, \: ...) \in \mathbb{R}^s\), where \(g_i\) is feature (or label) \(i\) of the graph, which is usually the prediction target.
For instance, let’s look at the following undirected graph with node features:
This graph has 4 nodes and 4 edges. The nodes are \(V=\{1,2,3,4\}\), and edges \(E=\{(1,2), (2,3), (2,4), (3,4)\}\). Note that for simplicity, we don’t add mirrored pairs like \((2,1)\). And we have the following node features:
\[ \mathbf{X} = \begin{bmatrix} 0 & 1 & 2\\ 1 & 0 & 1\\ 1 & 1 & 0\\ 3 & 1 & 4 \end{bmatrix} \]
The adjacency matrix \(A\) is a square matrix whose elements indicate whether pairs of nodes are adjacent, i.e. connected, or not. In the simplest case, \(A_{ij}\) is 1 if there is a connection from node \(i\) to \(j\), and otherwise 0. For an undirected graph, keep in mind that \(A\) is a symmetric matrix (\(A_{ij}=A_{ji}\)). For the example graph above, we have the following adjacency matrix:
\[ A = \begin{bmatrix} 0 & 1 & 0 & 0\\ 1 & 0 & 1 & 1\\ 0 & 1 & 0 & 1\\ 0 & 1 & 1 & 0 \end{bmatrix} \]
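As a quick sanity check, we can write down \(\mathbf{X}\) and \(A\) for this example in plain NumPy and verify that \(A\) is symmetric. This is a hypothetical illustration; the tutorial’s own code below builds graphs with RDKit and PyG instead:

```python
import numpy as np

# Node features of the example graph, shape (|V|, m) = (4, 3)
X = np.array([[0, 1, 2],
              [1, 0, 1],
              [1, 1, 0],
              [3, 1, 4]])

# Edges (1-based, one direction each, as in the text)
edges = [(1, 2), (2, 3), (2, 4), (3, 4)]
A = np.zeros((4, 4), dtype=int)
for i, j in edges:
    A[i - 1, j - 1] = A[j - 1, i - 1] = 1  # undirected: mirror each edge

assert (A == A.T).all()  # adjacency matrix of an undirected graph is symmetric
print(A)
```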
Molecular graph
A molecular graph is a labeled graph whose nodes correspond to the atoms of the compound and edges correspond to chemical bonds. It also has node features (atom features), edge features (bond features) and graph labels (chemical properties of a molecule). Next, we demonstrate a simple example of building a molecular graph (undirected). In this example, we do not consider hydrogen atoms as nodes.
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
IPythonConsole.ipython_useSVG = True               # < use SVGs instead of PNGs
IPythonConsole.drawOptions.addAtomIndices = True   # adding indices for atoms
IPythonConsole.drawOptions.addBondIndices = False  # not adding indices for bonds
IPythonConsole.molSize = 200, 200

# N,N-dimethylformamide (DMF)
dmf_smiles = 'CN(C)C=O'
mol = MolFromSmiles(dmf_smiles)
# show molecular graph of DMF, atom indices = node indices
mol
Atom features
| feature | description |
|---|---|
| atom type | atomic number |
| degree | number of directly-bonded neighbor atoms, including H atoms |
| formal charge | integer electronic charge assigned to atom |
| hybridization | sp, sp2, sp3, sp3d, or sp3d2 |
ATOM_FEATURES = {
    'atom_type'     : [1, 6, 7, 8, 9],  # elements: H, C, N, O, F
    'degree'        : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
    'formal_charge' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
    'hybridization' : [
        'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'
    ],
}

def get_atom_fv(atom):
    """
    Converts rdkit atom object to feature list of indices
    :param atom: rdkit atom object
    :return: list
    """
    atom_fv = [
        ATOM_FEATURES['atom_type'].index(atom.GetAtomicNum()),
        ATOM_FEATURES['degree'].index(atom.GetTotalDegree()),
        ATOM_FEATURES['formal_charge'].index(atom.GetFormalCharge()),
        ATOM_FEATURES['hybridization'].index(str(atom.GetHybridization())),
    ]
    return atom_fv

atom_fvs = [get_atom_fv(atom) for atom in mol.GetAtoms()]
atom_fvs
Bond features
| feature | description |
|---|---|
| bond type | single, double, triple, or aromatic |
| stereo | none, any, E/Z or cis/trans |
| conjugated | whether the bond is conjugated |
# Show indices of bonds
IPythonConsole.drawOptions.addAtomIndices = False  # not adding indices for atoms
IPythonConsole.drawOptions.addBondIndices = True   # adding indices for bonds
mol
BOND_FEATURES = {
    'bond_type' : [
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'AROMATIC',
        'misc'
    ],
    'stereo': [
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    'conjugated': [False, True],
}

def get_bond_fv(bond):
    """
    Converts rdkit bond object to feature list of indices
    :param bond: rdkit bond object
    :return: list
    """
    bond_fv = [
        BOND_FEATURES['bond_type'].index(str(bond.GetBondType())),
        BOND_FEATURES['stereo'].index(str(bond.GetStereo())),
        BOND_FEATURES['conjugated'].index(bond.GetIsConjugated()),
    ]
    return bond_fv

bond_fvs = [get_bond_fv(bond) for bond in mol.GetBonds()]
bond_fvs
Edge index
In many cases, a list of paired node indices is used to describe the edges rather than an adjacency matrix. Here we use paired node indices (edge_index) with shape (2, num_edges) to define the edges in a graph.
\[ \mathbf{E} = \begin{bmatrix} ..., & i, & ..., & j, & ... \\ ..., & j, & ..., & i, & ... \end{bmatrix} \]
That is, each column \((i, j)\) indicates an edge between node \(i\) and node \(j\); for an undirected graph, each edge is stored in both directions.
edge_index0, edge_index1 = [], []

for bond in mol.GetBonds():
    i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    edge_index0 += [i, j]
    edge_index1 += [j, i]

edge_index = [edge_index0, edge_index1]
edge_index
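Since every bond is stored in both directions, edge_index should have twice as many columns as the molecule has bonds. A small check we add for illustration:

```python
# DMF has 4 bonds, so each index list should contain 8 entries
assert len(edge_index[0]) == 2 * mol.GetNumBonds()
```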
Molecular graph data
We set the density of DMF (0.944 \(g/cm^3\)) as the graph feature (label). Here we use the Data class in PyG to create a graph data object for DMF.
import torch
from torch_geometric.data import Data
# convert our data to tensors, which are used for model training
x = torch.tensor(atom_fvs, dtype=torch.float)
edge_index = torch.tensor(edge_index, dtype=torch.long)
edge_attr = torch.tensor(bond_fvs, dtype=torch.float)
y = torch.tensor([0.944], dtype=torch.float)

dmf_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
dmf_data
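The resulting Data object exposes handy summary attributes (these are standard PyG Data properties) that we can use to double-check the graph we just built:

```python
print(dmf_data.num_nodes)  # 5 heavy atoms (hydrogens are not nodes here)
print(dmf_data.num_edges)  # 8 directed edges (4 bonds x 2 directions)
print(dmf_data.x.shape)    # node feature matrix: [5, 4]
```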
2.2 Graph Neural Network
A graph neural network (GNN) is a class of artificial neural networks for processing data that can be represented as graphs. GNNs rely on message passing methods, which means that nodes exchange information with their neighbors by sending “messages” to each other. Generally, GNNs operate in two phases: a message passing phase, which transmits information across the molecule to build a neural representation of the molecule, and a readout phase, which uses the final representation of the molecule to make predictions about the properties of interest.
Message passing
Before looking at the math, we can try to visually understand how message passing works. The first step is that each node creates a feature vector that represents the message it wants to send to all its neighbors. In the second step, the messages are sent to the neighbors, so that a node receives one message per adjacent node. As shown in the figure below, after a message passing step, node 1 can get the message from node 2, and node 2 can get messages from node 1, node 3 and node 4. The third step is that each node aggregates all messages from its neighbors into a message vector. Then, the fourth step is that each node updates its feature vector based on its message vector and its previous feature vector.
Moreover, as message passing iterates, each node can obtain the feature vectors of more distant nodes, not just its neighbors. As shown in the figure below, node A can get information from node E and node F in iteration 2, even though they are not neighbors of node A. Node C, a neighbor of node A, obtains the information of nodes E and F in iteration 1, so node A can obtain the information of nodes E and F in iteration 2.
You can also read the mathematical formulas to better understand the process of message passing.
- Initialization: get the initial hidden feature vector \(h_i^0\) of node \(i\) from its original node features \(x_i\): \[ h_i^0 = I(x_i), \quad \forall i \in V \tag{1} \] where \(I\) is the initialization function
- Message sending: \[ m_{j \rightarrow i}^{t+1} = M(h_i^t, \: h_j^t, \: e_{i,j}) \tag{2} \] where \(m_{j \rightarrow i}^{t+1}\) is the message sent from node \(j\) to node \(i\) at iteration \(t+1\), \(M\) is the message function, and \(e_{i,j}\) is the feature of the edge between nodes \(i\) and \(j\)
- Message aggregation: \[ m_i^{t+1} = \sum_{j \in N(i)} m_{j \rightarrow i}^{t+1} \tag{3} \] where \(N(i)\) denotes all neighbor nodes of node \(i\), and \(m_i^{t+1}\) is the aggregated message of node \(i\) at iteration \(t+1\)
- Node update: \[ h_i^{t+1} = U(h_i^t, \: m_i^{t+1}) \tag{4} \] where \(h_i^t\) is the hidden feature vector of node \(i\) at iteration \(t\), and \(U\) is the update function
Readout
The readout layers will aggregate the hidden feature vectors of all nodes and get graph-level vectors (that is, the properties we want to predict).
\[ \hat{y} = R(\{ h_i^T \: | \: i \in V\}) \tag{5} \] where \(h_i^T\) is the final hidden feature vector of node \(i\), \(\hat{y}\) is the graph-level vector (our prediction target), and \(R\) is the readout function
Note that in GNNs, the \(I\), \(M\), \(U\) and \(R\) functions need to be differentiable, such as multi-layer artificial neural networks.
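To make Eqs. (1)-(5) concrete, here is a minimal sketch of a single message-passing iteration plus readout on the 4-node example graph from section 2.1, using simple linear layers for \(I\), \(M\) and \(U\), sum aggregation, and a sum readout for \(R\). This toy code is our own illustration (it omits edge features in \(M\)) and is not the MPNN model trained below:

```python
import torch

# Example graph from section 2.1 (0-based), each edge stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 1, 3, 2, 3],
                           [1, 0, 2, 1, 3, 1, 3, 2]])
x = torch.tensor([[0., 1., 2.],
                  [1., 0., 1.],
                  [1., 1., 0.],
                  [3., 1., 4.]])   # node features, shape (|V|, m) = (4, 3)

d = 8                              # hidden dimension
I = torch.nn.Linear(3, d)          # Eq. (1): initialization function
M = torch.nn.Linear(2 * d, d)      # Eq. (2): message function (edge features omitted)
U = torch.nn.Linear(2 * d, d)      # Eq. (4): update function

h = I(x)                                            # h^0 for all nodes, shape (4, d)
src, dst = edge_index                               # each column is a message j -> i
msg = M(torch.cat([h[dst], h[src]], dim=-1))        # Eq. (2): m_{j->i} for all edges
agg = torch.zeros_like(h).index_add_(0, dst, msg)   # Eq. (3): sum messages per node
h = U(torch.cat([h, agg], dim=-1))                  # Eq. (4): h^1
y_hat = h.sum(dim=0)                                # Eq. (5): sum readout R
print(y_hat.shape)                                  # graph-level vector of size d
```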
Code example
Here, we will define a GNN model using a message passing neural network (MPNN), following the paper “Neural Message Passing for Quantum Chemistry”. We use the NNConv class to create the message passing layers of our model. The torch_geometric.nn module of PyG contains many different types of layers for message passing and readout, which can help us define GNN models more conveniently.
import torch
import torch.nn.functional as F
from torch.nn import GRU
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, MLP, global_add_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
class MPNN(pl.LightningModule):
    def __init__(self, hidden_dim, out_dim,
                 train_data, valid_data, test_data,
                 std, batch_size=32, lr=1e-3):
        super().__init__()
        self.std = std  # std of data's target
        self.train_data = train_data
        self.valid_data = valid_data
        self.test_data = test_data
        self.batch_size = batch_size
        self.lr = lr
        # Initial layers
        self.atom_emb = AtomEncoder(emb_dim=hidden_dim)
        self.bond_emb = BondEncoder(emb_dim=hidden_dim)
        # Message passing layers
        nn = MLP([hidden_dim, hidden_dim*2, hidden_dim*hidden_dim])
        self.conv = NNConv(hidden_dim, hidden_dim, nn, aggr='mean')
        self.gru = GRU(hidden_dim, hidden_dim)
        # Readout layers
        self.mlp = MLP([hidden_dim, int(hidden_dim/2), out_dim])

    def forward(self, data, mode="train"):
        # Initialization
        x = self.atom_emb(data.x)
        h = x.unsqueeze(0)
        edge_attr = self.bond_emb(data.edge_attr)

        # Message passing
        for i in range(3):
            m = F.relu(self.conv(x, data.edge_index, edge_attr))  # send message and aggregation
            x, h = self.gru(m.unsqueeze(0), h)                    # node update
            x = x.squeeze(0)

        # Readout
        x = global_add_pool(x, data.batch)
        x = self.mlp(x)

        return x.view(-1)

    def training_step(self, batch, batch_idx):
        # Here we define the train loop.
        out = self.forward(batch, mode="train")
        loss = F.mse_loss(out, batch.y)
        self.log("Train loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Define validation step. At the end of every epoch, this will be executed
        out = self.forward(batch, mode="valid")
        loss = F.mse_loss(out * self.std, batch.y * self.std)  # report MSE in original units
        self.log("Valid MSE", loss)

    def test_step(self, batch, batch_idx):
        # What to do in test
        out = self.forward(batch, mode="test")
        loss = F.mse_loss(out * self.std, batch.y * self.std)  # report MSE in original units
        self.log("Test MSE", loss)

    def configure_optimizers(self):
        # Here we configure the optimization algorithm.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valid_data, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False)
Here, we can use the InMemoryDataset class in PyG to create the ESOL graph dataset conveniently. You can also browse its tutorial and pre-defined datasets to learn how to create graph datasets quickly with PyG.
from tqdm import tqdm
import pandas as pd
import torch
from torch_geometric.data import (
Data,
InMemoryDataset,
download_url,
)
from ogb.utils import smiles2graph

class ESOLGraphData(InMemoryDataset):
    """The ESOL graph dataset using PyG
    """
    # ESOL dataset download link
    raw_url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv'

    def __init__(self, root, transform=None):
        super().__init__(root, transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['delaney-processed.csv']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        print('Downloading ESOL dataset...')
        file_path = download_url(self.raw_url, self.raw_dir)

    def process(self):
        # load raw data from a csv file
        df = pd.read_csv(self.raw_paths[0])
        smiles = df['smiles'].values.tolist()
        target = df['measured log solubility in mols per litre'].values.tolist()

        # Convert SMILES into graph data
        print('Converting SMILES strings into graphs...')
        data_list = []
        for i, smi in enumerate(tqdm(smiles)):
            # get graph data from SMILES
            graph = smiles2graph(smi)

            # convert to tensor and pyg data
            x = torch.tensor(graph['node_feat'], dtype=torch.long)
            edge_index = torch.tensor(graph['edge_index'], dtype=torch.long)
            edge_attr = torch.tensor(graph['edge_feat'], dtype=torch.long)
            y = torch.tensor([target[i]], dtype=torch.float)
            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

            data_list.append(data)

        # save data
        torch.save(self.collate(data_list), self.processed_paths[0])
Create, normalize and split the ESOL graph dataset.
from deepchem.splits import RandomSplitter
# create dataset
dataset = ESOLGraphData('./esol_pyg').shuffle()

# Normalize target to mean = 0 and std = 1.
mean = dataset.data.y.mean()
std = dataset.data.y.std()
dataset.data.y = (dataset.data.y - mean) / std
mean, std = mean.item(), std.item()

# split data
splitter = RandomSplitter()
train_idx, valid_idx, test_idx = splitter.split(dataset, frac_train=0.7, frac_valid=0.1, frac_test=0.2)
train_dataset = dataset[train_idx]
valid_dataset = dataset[valid_idx]
test_dataset = dataset[test_idx]
# This will ask you to login to your wandb account
import wandb
="gnn-solubility",
wandb.init(project={
config"batch_size": 32,
"learning_rate": 0.001,
"hidden_size": 64,
"max_epochs": 60
})
Train and evaluate the model.
# Here we create an instance of our GNN.
# Play around with the hyperparameters!
gnn_model = MPNN(
    hidden_dim=wandb.config["hidden_size"],
    out_dim=1,
    std=std,
    train_data=train_dataset,
    valid_data=valid_dataset,
    test_data=test_dataset,
    lr=wandb.config["learning_rate"],
    batch_size=wandb.config["batch_size"]
)

# Define trainer: How we want to train the model
wandb_logger = WandbLogger()
trainer = pl.Trainer(
    max_epochs=wandb.config["max_epochs"],
    logger=wandb_logger
)

# Finally! Training a model :)
trainer.fit(model=gnn_model)

# Now run test
results = trainer.test(ckpt_path="best")

wandb.finish()

# Test RMSE
test_mse = results[0]["Test MSE"]
test_rmse = test_mse ** 0.5
print(f"\nMPNN model performance: RMSE on test set = {test_rmse:.4f}.\n")